import argparse
import os

import yaml


def prompt_func(mode, lang):
    """Return the MasakhaNEWS topic-classification prompt template for ``mode``.

    The ``{{headline}}`` placeholder is kept verbatim in the returned string
    (the literal pieces that contain it are deliberately plain strings, not
    f-strings, so the double braces are not collapsed). It is presumably
    substituted later by whatever consumes the generated YAML.

    :param mode: Prompt identifier, one of ``"prompt_1"`` .. ``"prompt_5"``.
    :param lang: Human-readable language name (e.g. ``"Hausa"``); it is
        interpolated into prompts 2, 3 and 5 only.
    :return: The prompt template string.
    :raises KeyError: If ``mode`` is not a known prompt identifier.
    """
    # Every prompt offers the same seven labels: technology, business,
    # politics, sports, health, entertainment and religion.
    prompt_map = {
        "prompt_1": "Given the categories technology, business, politics, sports, health, entertainment, or religion; what category does the text: '{{headline}}' belong to: \n\n",
        "prompt_2": f"Does this {lang} topic; "
        "'{{headline}}' belong to one of the following categories: technology, business, politics, sports, health, entertainment, or religion? category only\n\n",
        "prompt_3": "You are an assistant able to classify topics in texts. \n\n"
        "Given the categories technology, religion, politics, sports, health, entertainment, or business; what is "
        f"the topic of the {lang} statement below? Return only the category. "
        "\n\ntext: {{headline}} \\category:\n\n",
        # Previously listed "geography", which is not a label in this task;
        # "business" is the seventh category used by all other prompts.
        "prompt_4": "Label the following text as technology, religion, politics, sports, health, entertainment, or business. Provide only the category as your "
        "response. \n\ntext: {{headline}} \\category: \n\n",
        "prompt_5": f"You are tasked with performing topic classification on the following {lang} text. "
        "For each input, classify the topic as technology, business, politics, sports, health, entertainment, or religion. "
        "Use the following guidelines: \n\n "
        "technology: The text discusses scientific discoveries, technological advancements, or related topics. \n"
        "politics: The text covers political events, policies, or related topics. \n"
        "sports: The text talks about sports events, athletes, or related topics. \n"
        "health: The text addresses health issues, medical advancements, or related topics. \n"
        "entertainment: The text pertains to movies, music, celebrities, or related topics. \n"
        "religion: The text talks about religions, religious institutions and beliefs or related topics. \n\n"
        "business: The text covers economy, business, or related topics. \n\n"
        "If the text contains multiple topics, choose the dominant topic. "
        "For ambiguous or unclear topics, select the category that best reflects the overall content. "
        "Please provide a single classification for each input.\n\ntext: {{headline}} \\category: \n\n",
    }
    return prompt_map[mode]


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    One file per language is written to
    ``<output_dir>/<mode>/masakhanews_<lang>.yaml``. Each file includes the
    shared ``masakhanews`` template and sets the task name, the
    ``dataset_name`` (the language code) and the ``doc_to_text`` prompt
    selected by ``mode``.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompt identifier (e.g. ``"prompt_1"``); selects the prompt
        and names the output sub-directory.
    :raises FileExistsError: If ``overwrite`` is False and one or more target
        files already exist. The remaining files are still written first, and
        the error lists every skipped file. Also raised directly (unwrapped) by
        ``os.makedirs`` if ``<output_dir>/<mode>`` exists but is not a directory.
    """
    err = []
    languages = {
        "amh": "Amharic",
        "eng": "English",
        "fra": "French",
        "hau": "Hausa",
        "ibo": "Igbo",
        "lin": "Lingala",
        "lug": "Luganda",
        "orm": "Afaan Oromoo",
        "pcm": "Nigerian Pidgin",
        "run": "Kirundi",
        "sna": "Shona",
        "som": "Somali",
        "swa": "Swahili",
        "tir": "Tigrinya",
        "xho": "Xhosa",
        "yor": "Yoruba",
    }

    # Create the per-mode directory once, outside the try below: makedirs
    # raises FileExistsError when the path exists as a regular file, and that
    # must not be misreported as "files already exist (use --overwrite flag)".
    mode_dir = os.path.join(output_dir, mode)
    os.makedirs(mode_dir, exist_ok=True)

    # "x" makes open() raise FileExistsError instead of clobbering a file.
    open_mode = "w" if overwrite else "x"
    yaml_template = "masakhanews"

    for lang, lang_name in languages.items():
        file_name = f"masakhanews_{lang}.yaml"
        yaml_details = {
            "include": yaml_template,
            "task": f"masakhanews_{lang}_{mode}",
            "dataset_name": lang,
            "doc_to_text": prompt_func(mode, lang_name),
        }
        try:
            with open(
                os.path.join(mode_dir, file_name),
                open_mode,
                encoding="utf8",
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    yaml_details,
                    f,
                    allow_unicode=True,
                )
        except FileExistsError:
            # Keep going so every collision is reported in one error.
            err.append(file_name)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """Parse CLI args and generate language-specific yaml files.

    Existing files are overwritten by default (as before). Pass
    ``--no-overwrite`` to leave existing files untouched; in that case
    ``gen_lang_yamls`` raises ``FileExistsError`` listing the skipped files.
    """
    parser = argparse.ArgumentParser()
    # Previously ``--overwrite`` was store_true with default=True, so it was a
    # no-op and overwriting could never be disabled. ``--overwrite`` is kept
    # for backward compatibility; ``--no-overwrite`` is the switch that
    # actually changes behaviour.
    parser.add_argument(
        "--overwrite",
        dest="overwrite",
        action="store_true",
        help="Overwrite files if they already exist (default)",
    )
    parser.add_argument(
        "--no-overwrite",
        dest="overwrite",
        action="store_false",
        help="Do not overwrite files that already exist; report them instead",
    )
    parser.set_defaults(overwrite=True)
    parser.add_argument(
        "--output-dir",
        default="./",
        help="Directory to write yaml files to",
    )

    PROMPT_CHOICES = ["prompt_1", "prompt_2", "prompt_3", "prompt_4", "prompt_5"]
    parser.add_argument(
        "--mode",
        nargs="*",
        default=PROMPT_CHOICES,
        choices=PROMPT_CHOICES,
        help="Prompt number(s)",
    )
    args = parser.parse_args()

    for mode in args.mode:
        gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=mode)


if __name__ == "__main__":
    main()
